# -*- coding:utf-8 -*-

'''
得到一项集
'''


def get_distinct_item(data_set):
    """Build the candidate 1-itemsets from a collection of transactions.

    Args:
        data_set: Iterable of transactions, each an iterable of hashable,
            mutually orderable items.

    Returns:
        A list of single-element frozensets, one per distinct item, ordered
        by ascending item value.
    """
    # A set gives O(1) duplicate detection instead of scanning a list for
    # every item of every transaction.
    distinct_items = {item for transaction in data_set for item in transaction}
    return [frozenset([item]) for item in sorted(distinct_items)]


'''
获得频繁项集及支持度
'''


def get_frequent(data_set, candidates, min_support=0.2):
    """Compute candidate supports and keep the candidates that are frequent.

    Args:
        data_set: Sized collection of transactions; each transaction must be
            usable as the argument of ``frozenset.issubset``.
        candidates: Iterable of hashable candidate itemsets (frozensets).
        min_support: Minimum fraction of transactions that must contain an
            itemset for it to be reported as frequent.

    Returns:
        A ``(res_frequent, support_res)`` tuple. ``res_frequent`` lists the
        itemsets whose support is at least ``min_support``, in reverse order
        of their first occurrence in the data. ``support_res`` maps every
        candidate that occurs in at least one transaction to its support,
        including candidates below the threshold.
    """
    # Drop repeated candidates (keeping first-seen order): a candidate that
    # appears twice would otherwise be counted twice per transaction and
    # report a support above its true value.
    unique_candidates = list(dict.fromkeys(candidates))

    counts = {}
    for transaction in data_set:
        for candidate in unique_candidates:
            if candidate.issubset(transaction):
                counts[candidate] = counts.get(candidate, 0) + 1

    size_num = len(data_set)
    support_res = {}
    res_frequent = []
    for candidate, count in counts.items():
        support = count / size_num
        support_res[candidate] = support
        if support >= min_support:
            res_frequent.append(candidate)
    # Appending and reversing once yields the same order as inserting each
    # itemset at the front, without the per-insert list shift.
    res_frequent.reverse()
    return res_frequent, support_res


'''
获取候选K项集
list_k的元素为set集合,且set集合中至少有1个元素
k从2开始,因为频繁一项式已经通过函数get_distinct_item获得
'''


def get_candidates_list(list_k, k):
    """Join (k-1)-itemsets into candidate k-itemsets.

    Two itemsets are merged when their ``k - 2`` smallest items are equal,
    which is the join step of candidate generation in Apriori.

    Args:
        list_k: Sequence of sets/frozensets, each holding ``k - 1`` mutually
            orderable items.
        k: Size of the candidate itemsets to produce (2 or greater).

    Returns:
        A list with the union of every pair of itemsets in ``list_k`` that
        share the same sorted ``k - 2`` item prefix, in pair order.
    """
    prefix_len = k - 2
    # Sort BEFORE slicing: set iteration order is arbitrary, so slicing the
    # raw set first would compare unrelated items and miss or duplicate
    # candidates. Computing each prefix once also avoids re-sorting per pair.
    prefixes = [sorted(itemset)[:prefix_len] for itemset in list_k]

    res_list = []
    size_num = len(list_k)
    for i in range(size_num):
        for j in range(i + 1, size_num):
            if prefixes[i] == prefixes[j]:
                res_list.append(list_k[i] | list_k[j])  # set union
    return res_list


def apriori(data_set, min_support=0.2):
    """Run the level-wise Apriori search over ``data_set``.

    Args:
        data_set: Collection of transactions (each an iterable of items).
        min_support: Minimum support passed to ``get_frequent`` at every level.

    Returns:
        A ``(levels, support_data)`` tuple. ``levels[i]`` is the list of
        frequent itemsets of size ``i + 1``; the final entry is the empty
        list that ended the search. ``support_data`` accumulates the support
        dictionaries returned by ``get_frequent`` for every level.
    """
    single_items = get_distinct_item(data_set)  # candidate 1-itemsets
    transactions = [set(record) for record in data_set]  # one set per transaction
    # Frequent 1-itemsets and the support of each single item.
    frequent_level, support_data = get_frequent(transactions, single_items, min_support)
    levels = [frequent_level]
    size = 2
    # Keep growing itemsets until a level yields no frequent itemset.
    while levels[-1]:
        next_candidates = get_candidates_list(levels[-1], size)
        frequent_level, level_support = get_frequent(transactions, next_candidates, min_support)
        support_data.update(level_support)
        levels.append(frequent_level)
        size += 1
    return levels, support_data


def generate_rules(L, support_data, min_conf=0.3):
    """Derive association rules from the frequent itemsets.

    Args:
        L: Sequence of itemset lists; ``L[i]`` holds the frequent itemsets
            of size ``i + 1``.
        support_data: Mapping from itemset (frozenset) to its support.
        min_conf: Minimum confidence a rule must reach to be kept.

    Returns:
        A list of ``(antecedent, consequent, confidence)`` tuples.
    """
    res_list = []
    for i in range(1, len(L)):  # start at the 2-itemsets; single items form no rule
        for f in L[i]:
            H1 = [frozenset([x]) for x in f]
            # Always evaluate single-item consequents, so that itemsets with
            # three or more items also yield rules such as {a, b} -> {c}.
            kept = calculate_conf(f, H1, support_data, res_list, min_conf)
            # Larger consequents are grown only from the surviving ones: a
            # consequent cannot reach min_conf if one of its subsets failed,
            # and at least two survivors are needed to merge.
            if i > 1 and len(kept) > 1:
                get_rules_set(f, kept, support_data, res_list, min_conf)
    return res_list


def calculate_conf(frequent, H, support_data, res_list, min_conf=0.3):
    """Evaluate the rules ``(frequent - h) -> h`` for every consequent in H.

    Each rule that reaches ``min_conf`` is printed, and appended to
    ``res_list`` as an ``(antecedent, consequent, confidence)`` tuple.

    Args:
        frequent: The frequent itemset (frozenset) the rules are built from.
        H: Iterable of candidate consequents, each a subset of ``frequent``.
        support_data: Mapping from itemset to support; must contain
            ``frequent`` and every antecedent ``frequent - h``.
        res_list: List that accepted rules are appended to (mutated in place).
        min_conf: Minimum confidence a rule must reach to be kept.

    Returns:
        The consequents whose rule reached ``min_conf``.

    Raises:
        KeyError: If a required itemset is missing from ``support_data``.
    """
    pruned_list = []
    for h in H:
        antecedent = frequent - h
        # confidence(A -> B) = support(A ∪ B) / support(A); the numerator is
        # the support of the whole itemset, not of the consequent alone.
        conf = support_data[frequent] / support_data[antecedent]
        if conf >= min_conf:
            print(str(antecedent) + "--->" + str(h) + " confidence:" + str(conf))
            res_list.append((antecedent, h, conf))
            pruned_list.append(h)
    return pruned_list


def get_rules_set(frequent, H, support_data, res_list, min_conf=0.3):
    """Generate rules for ``frequent`` with progressively larger consequents.

    Starting from the consequents in ``H`` (all of the same size ``m``), the
    consequents are merged into size ``m + 1`` candidates, scored with
    ``calculate_conf``, and the survivors are merged again, until the
    consequent would leave no antecedent or fewer than two survive.

    Args:
        frequent: The frequent itemset (frozenset) the rules are built from.
        H: List of equally sized consequents, each a subset of ``frequent``.
        support_data: Mapping from itemset to support.
        res_list: List that accepted rules are appended to (mutated in place).
        min_conf: Minimum confidence a rule must reach to be kept.

    Returns:
        None. Results are delivered through ``res_list``.
    """
    consequents = H
    # An empty consequent list has nothing to merge (and no H[0] to measure).
    while consequents:
        m = len(consequents[0])
        # A consequent of size m + 1 must leave at least one antecedent item.
        if len(frequent) <= m + 1:
            break
        merged = get_candidates_list(consequents, m + 1)
        consequents = calculate_conf(frequent, merged, support_data, res_list, min_conf)
        # At least two surviving consequents are needed for another merge.
        if len(consequents) <= 1:
            break


# x1, y1 = apriori(data)
# print(x1)
# print(y1)
# x2 = generate_rules(x1, y1)
# print(x2)
